{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "da5d9bc0-95ab-45d4-9378-417628d86e35",
   "metadata": {},
   "source": [
    "## 4.7 生成文本"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48da5deb-6ee0-4b9b-8dd2-abed7ed65172",
   "metadata": {},
   "source": [
    "- 我们上面实现的GPT模型等LLMs（大型语言模型）被用来一次生成一个单词"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "caade12a-fe97-480f-939c-87d24044edff",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/16.webp\" width=\"100%\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7061524-a3bd-4803-ade6-2e3b7b79ac13",
   "metadata": {},
   "source": [
    "\n",
    "以下是对给定文本的中英文翻译：\n",
    "\n",
    "- 以下的`generate_text_simple`函数实现了贪婪解码，这是一种简单且快速的文本生成方法\n",
    "- 在贪婪解码中，每一步，模型都会选择概率最高的词（或标记）作为下一个输出（最高的对数值对应于最高的概率，因此我们甚至不必显式计算softmax函数）\n",
    "- 在下一章中，我们将实现一个更高级的`generate_text`函数\n",
    "- 下图展示了GPT模型在给定输入上下文时如何生成下一个词标记\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ee0f32c-c18c-445e-b294-a879de2aa187",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/17.webp\" width=\"100%\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "c9b428a9-8764-4b36-80cd-7d4e00595ba6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_text_simple(model, idx, max_new_tokens, context_size):\n",
    "    # idx is (batch, n_tokens) array of indices in the current context\n",
    "    for _ in range(max_new_tokens):\n",
    "        \n",
    "        # Crop current context if it exceeds the supported context size\n",
    "        # E.g., if LLM supports only 5 tokens, and the context size is 10\n",
    "        # then only the last 5 tokens are used as context\n",
    "        idx_cond = idx[:, -context_size:]\n",
    "        \n",
    "        # Get the predictions\n",
    "        with torch.no_grad():\n",
    "            logits = model(idx_cond)\n",
    "        \n",
    "        # Focus only on the last time step\n",
    "        # (batch, n_tokens, vocab_size) becomes (batch, vocab_size)\n",
    "        logits = logits[:, -1, :]  \n",
    "\n",
    "        # Apply softmax to get probabilities\n",
    "        probas = torch.softmax(logits, dim=-1)  # (batch, vocab_size)\n",
    "\n",
    "        # Get the idx of the vocab entry with the highest probability value\n",
    "        idx_next = torch.argmax(probas, dim=-1, keepdim=True)  # (batch, 1)\n",
    "\n",
    "        # Append sampled index to the running sequence\n",
    "        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)\n",
    "\n",
    "    return idx"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6515f2c1-3cc7-421c-8d58-cc2f563b7030",
   "metadata": {},
   "source": [
    "\n",
    "上面的`generate_text_simple`实现了一个迭代过程，其中它一次生成一个标记\n",
    "\n",
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/18.webp\" width=\"100%\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f682eac4-f9bd-438b-9dec-6b1cc7bc05ce",
   "metadata": {},
   "source": [
    "- 让我们准备一个输入示例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "3d7e3e94-df0f-4c0f-a6a1-423f500ac1d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoded: [15496, 11, 314, 716]\n",
      "encoded_tensor.shape: torch.Size([1, 4])\n"
     ]
    }
   ],
   "source": [
    "start_context = \"Hello, I am\"\n",
    "\n",
    "encoded = tokenizer.encode(start_context)\n",
    "print(\"encoded:\", encoded)\n",
    "\n",
    "encoded_tensor = torch.tensor(encoded).unsqueeze(0)\n",
    "print(\"encoded_tensor.shape:\", encoded_tensor.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "a72a9b60-de66-44cf-b2f9-1e638934ada4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output: tensor([[15496,    11,   314,   716, 27018, 24086, 47843, 30961, 42348,  7267]])\n",
      "Output length: 10\n"
     ]
    }
   ],
   "source": [
    "model.eval() # disable dropout\n",
    "\n",
    "out = generate_text_simple(\n",
    "    model=model,\n",
    "    idx=encoded_tensor, \n",
    "    max_new_tokens=6, \n",
    "    context_size=GPT_CONFIG_124M[\"context_length\"]\n",
    ")\n",
    "\n",
    "print(\"Output:\", out)\n",
    "print(\"Output length:\", len(out[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d131c00-1787-44ba-bec3-7c145497b2c3",
   "metadata": {},
   "source": [
    "- 移除批次维度并转换回文本："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "053d99f6-5710-4446-8d52-117fb34ea9f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hello, I am Featureiman Byeswickattribute argue\n"
     ]
    }
   ],
   "source": [
    "decoded_text = tokenizer.decode(out.squeeze(0).tolist())\n",
    "print(decoded_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a894003-51f6-4ccc-996f-3b9c7d5a1d70",
   "metadata": {},
   "source": [
    "\n",
    "- 注意模型未经过训练；因此上述输出文本是随机的\n",
    "- 我们将在下一章中训练模型\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a35278b6-9e5c-480f-83e5-011a1173648f",
   "metadata": {},
   "source": [
    "## 总结和要点\n",
    "\n",
    "- 请查看./gpt.py脚本，这是一个包含我们在本Jupyter笔记本中实现的GPT模型的独立脚本\n",
    "- 您可以在./exercise-solutions.ipynb中找到练习题的解答"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
